import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd

# Cap OpenMP worker threads at 4. This is set before the sklearn/scipy imports
# below, presumably so native libraries read it when they initialise — TODO
# confirm the ordering is intentional. numpy is already imported above, so its
# BLAS backend may not honour this value.
os.environ["OMP_NUM_THREADS"] = '4'

from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
import scipy
from scipy.spatial.distance import pdist, squareform

def neighbor_kept_ratio_eval(X, X_new, n_neighbors=30):
    '''
    Evaluate local structure preservation as the fraction of k-nearest
    neighbors shared between the high- and low-dimensional spaces.

    For every point, its ``n_neighbors`` nearest neighbors (excluding the
    point itself) are found in ``X`` and in ``X_new``. Every row of each
    neighbor graph has exactly ``n_neighbors`` edges, so the number of edges
    present in both graphs is divided by ``n_neighbors * N``. 1.0 means every
    neighborhood is preserved; 0.0 means none of them overlap.

    Input:
        X: A numpy array with the shape [N, p]. The higher dimension embedding
           of some dataset.
        X_new: A numpy array with the shape [N, k]. The lower dimension
               embedding of the same dataset; row i must correspond to row i
               of X.
        n_neighbors: Number of neighbors per point. Must be smaller than N.
    Output:
        neighbor_kept_ratio: The score, a float in [0, 1].
    Raises:
        ValueError: if X and X_new have a different number of rows.

    Source: https://github.com/hyhuang00/scRNA-DR2020/blob/main/experiments/run_eval.py
    '''
    n_samples = X.shape[0]
    if X_new.shape[0] != n_samples:
        raise ValueError(
            "X and X_new must have the same number of rows, got "
            f"{n_samples} and {X_new.shape[0]}"
        )

    # Calling kneighbors_graph() without a query set returns the neighbors of
    # the fitted points *excluding each point itself*. This is robust to
    # duplicate points, unlike querying with n_neighbors + 1 and subtracting
    # the identity: there, a duplicate can displace the point from its own
    # list and leave a -1 on the diagonal.
    graph_hd = NearestNeighbors(n_neighbors=n_neighbors).fit(X).kneighbors_graph()
    graph_ld = NearestNeighbors(n_neighbors=n_neighbors).fit(X_new).kneighbors_graph()

    # Element-wise product of the 0/1 sparse graphs keeps only shared edges.
    # Staying sparse avoids materialising two dense N x N arrays.
    neighbor_kept = graph_hd.multiply(graph_ld).sum()
    neighbor_kept_ratio = neighbor_kept / n_neighbors / n_samples
    return neighbor_kept_ratio

def global_score(X_high, X_low):
    """
    This is a function that evaluates the global structure preservation
    through the TriMap Global Score metric.

    Both inputs are mean-centred. The loss of an embedding Y is the mean
    squared residual of the best (least-squares) linear reconstruction of
    X_high from Y. The embedding's loss is compared with the loss of a PCA
    embedding of the same dimensionality:

        GS = exp(-(loss_emb - loss_pca) / loss_pca)

    GS == 1 means the embedding reconstructs X_high linearly as well as PCA
    does. A larger reconstruction loss gives a smaller score.

    Input:
        X_high: Instance matrix of shape [N, p] (the original data).
        X_low: Embedding of shape [N, k]; row i corresponds to row i of X_high.
    Output:
        Global score

    Raises:
        numpy.linalg.LinAlgError: if the centred X_low (or its PCA
        counterpart) is rank-deficient, so Y^T Y is singular.

    Source: https://github.com/eamid/trimap/blob/master/trimap/trimap_.py#L869
    """

    def global_loss_(X, Y):
        X = X - np.mean(X, axis=0)
        Y = Y - np.mean(Y, axis=0)
        # Least-squares map A = X^T Y (Y^T Y)^{-1}. Since Y^T Y is symmetric,
        # A^T = (Y^T Y)^{-1} Y^T X. Solving the normal equations is more
        # accurate than forming the explicit inverse when Y^T Y is
        # ill-conditioned.
        A = np.linalg.solve(Y.T @ Y, Y.T @ X).T
        return np.mean(np.power(X.T - A @ Y.T, 2))

    n_dims = X_low.shape[1]
    # Reference embedding: PCA with the same number of dimensions as X_low.
    Y_pca = PCA(n_components=n_dims).fit_transform(X_high)
    gs_pca = global_loss_(X_high, Y_pca)
    gs_emb = global_loss_(X_high, X_low)
    return np.exp(-(gs_emb - gs_pca) / gs_pca)


def compute_triplet_accuracy(X_high, X_low, triplets):
    """
    Return the fraction of triplets whose distance ordering agrees in both spaces.

    For each index triple (i, j, k), the sign of
    ``||x_i - x_j|| - ||x_i - x_k||`` is computed in X_high and in X_low. A
    triplet counts as preserved when the two signs are equal; an exact tie
    (sign 0) only matches another exact tie.

    Args:
        X_high: Row-indexable array of the original data.
        X_low: Row-indexable array of the embedding, aligned with X_high.
        triplets: Sequence of (i, j, k) index triples.

    Returns:
        The mean agreement (a numpy float in [0, 1]). If ``triplets`` is empty,
        returns the sentinel -1, which callers must not average in.
    """
    if len(triplets) == 0:
        return -1

    def _ordering(points, anchor_idx, first_idx, second_idx):
        # +1: anchor is closer to the second point, -1: closer to the first,
        # 0: equidistant.
        anchor = points[anchor_idx]
        gap = np.linalg.norm(anchor - points[first_idx]) - np.linalg.norm(anchor - points[second_idx])
        return np.sign(gap)

    signs_high = []
    signs_low = []
    for anchor_idx, first_idx, second_idx in triplets:
        signs_high.append(_ordering(X_high, anchor_idx, first_idx, second_idx))
        signs_low.append(_ordering(X_low, anchor_idx, first_idx, second_idx))

    matches = np.asarray(signs_high) == np.asarray(signs_low)
    return np.mean(matches)

"""
def compute_triplet_accuracy(d_high, d_low, triplets):
    if len(triplets) == 0:
        return np.nan
    # Indexes do not correspond! Need to change the way we index the triplets
    relative_d_high = [np.sign(d_high[i, j] - d_high[i, k]) for i, j, k in triplets]
    relative_d_low = [np.sign(d_low[i, j] - d_low[i, k]) for i, j, k in triplets]
    acc = np.mean(np.array(relative_d_high) == np.array(relative_d_low))
    
    return acc
"""

def _sample_distinct_triplets(n_points, n_triplets):
    """Draw an (n_triplets, 3) int array of rows with 3 distinct indices in [0, n_points)."""
    triplets = np.random.randint(0, n_points, size=(n_triplets, 3))
    # Rejection sampling: redraw every row that repeats an index. A row is
    # valid with probability (1 - 1/n)(1 - 2/n), so for n >= 3 this converges
    # quickly without the O(n) permutation that choice(replace=False) needs.
    while True:
        repeated = (
            (triplets[:, 0] == triplets[:, 1])
            | (triplets[:, 0] == triplets[:, 2])
            | (triplets[:, 1] == triplets[:, 2])
        )
        n_repeated = int(repeated.sum())
        if n_repeated == 0:
            return triplets
        triplets[repeated] = np.random.randint(0, n_points, size=(n_repeated, 3))


def _mean_std(values):
    """Return (mean, std) of values, or (nan, nan) when values is empty."""
    if not values:
        return np.nan, np.nan
    return np.mean(values), np.std(values)


def random_triplet_accuracy(X_high, X_low, neighbors_high, n_triplets=1000, n_repetitions=10, neighborhood_size=10):
    """
    Compute the Random Triplet Accuracy metric, which measures the proportion of triplets
    for which the relative distances are preserved in the embedding compared to the original data.

    In each repetition, ``n_triplets`` random triplets (i, j, k) of distinct
    indices are drawn with numpy's global RNG. Triplets with repeated indices
    are excluded because their ordering is trivially "preserved" and would
    inflate the score. The triplets are split into:
      * local:  j or k is in ``neighbors_high[i]``;
      * global: neither j nor k is in ``neighbors_high[i]``;
    and each subset, as well as the full set, is scored with
    compute_triplet_accuracy.

    Args:
        X_high: Array of shape [N, p], the original data.
        X_low: Array of shape [N, k], the embedding; row i corresponds to row
            i of X_high.
        neighbors_high: Indexable by point index; ``neighbors_high[i]`` is a
            container of the indices of i's high-dimensional neighbors
            (membership is tested with ``in``).
        n_triplets: Number of triplets drawn per repetition.
        n_repetitions: Number of independent draws.
        neighborhood_size: Currently unused; kept for backward compatibility.
            The neighborhood is defined entirely by ``neighbors_high``.

    Returns:
        dict: keys "local", "global" and "all", each mapping to a
        (mean, std) tuple of the Random Triplet Accuracy across repetitions.
        Repetitions in which a subset is empty are skipped rather than
        averaged in. If every repetition is empty for a subset, its tuple is
        (nan, nan).

    Raises:
        ValueError: if there are fewer than 3 points.
    """
    n_points = X_low.shape[0]
    if n_points < 3:
        raise ValueError(
            f"random_triplet_accuracy needs at least 3 points to form triplets, got {n_points}"
        )

    local_accs = []
    global_accs = []
    accs = []
    for _ in range(n_repetitions):
        triplets = _sample_distinct_triplets(n_points, n_triplets)

        # "Local" and "global" are complements: a triplet is local when at
        # least one of j, k is a high-dimensional neighbor of i.
        local_triplets = []
        global_triplets = []
        for i, j, k in triplets:
            if j in neighbors_high[i] or k in neighbors_high[i]:
                local_triplets.append((i, j, k))
            else:
                global_triplets.append((i, j, k))

        # Only score non-empty subsets. compute_triplet_accuracy returns a -1
        # sentinel for an empty list, and averaging that in would silently
        # corrupt the reported mean and std.
        for subset, bucket in (
            (local_triplets, local_accs),
            (global_triplets, global_accs),
            (triplets, accs),
        ):
            if len(subset) > 0:
                bucket.append(compute_triplet_accuracy(X_high, X_low, subset))

    rta = {
        "local": _mean_std(local_accs),
        "global": _mean_std(global_accs),
        "all": _mean_std(accs)
    }

    return rta


def spearman_correlation_eval(X_high, X_low, n_points=1000, n_repetitions=10):
    '''
    Evaluate distance preservation as the Spearman rank correlation between
    pairwise distances in the original data and in the embedding.

    In each repetition, ``n_points`` rows are sampled without replacement
    (numpy's global RNG). The Spearman correlation is then computed between
    the condensed pairwise Euclidean distance vectors of that sample, one
    entry per unordered pair.

    Args:
        X_high: Array of shape [N, p], the original data.
        X_low: Array of shape [N, k], the embedding; row i corresponds to row
            i of X_high.
        n_points: Points sampled per repetition; clipped to N.
        n_repetitions: Number of independent samples.

    Returns:
        (mean, std) of the correlation coefficient across repetitions.

    Adapted from https://github.com/hyhuang00/scRNA-DR2020/blob/main/experiments/run_eval.py
    '''
    from scipy.stats import spearmanr

    n_total = X_high.shape[0]
    # Sampling more points than exist without replacement would raise.
    n_points = min(n_points, n_total)

    corrs = []
    for _ in range(n_repetitions):
        # Sample n_points points from the dataset randomly
        sample_index = np.random.choice(n_total, n_points, replace=False)

        # Use condensed distance vectors. spearmanr treats each column of a
        # 2-D input as a separate variable, so square distance matrices would
        # yield a (2n x 2n) correlation matrix instead of one coefficient. The
        # square form would also double-count every pair and include the
        # zero diagonal.
        dist_high = pdist(X_high[sample_index])
        dist_low = pdist(X_low[sample_index])

        corr, _ = spearmanr(dist_high, dist_low)
        corrs.append(corr)

    corr_mean = np.mean(corrs)
    corr_std = np.std(corrs)

    return corr_mean, corr_std


def run_eval(X_high, X_low, neighbors_high, n_sample=100, neighborhood_size=50, n_repetitions=100):
    """
    Run every quality metric of this module on one embedding.

    Args:
        X_high: Original data, shape [N, p].
        X_low: Embedding, shape [N, k]; row i corresponds to row i of X_high.
        neighbors_high: Per-point high-dimensional neighbor indices, forwarded
            to random_triplet_accuracy.
        n_sample: Used both as the number of triplets per repetition for the
            triplet accuracy and as the number of points sampled per
            repetition for the distance correlation.
        neighborhood_size: Forwarded to random_triplet_accuracy. It is not
            passed to neighbor_kept_ratio_eval, which uses its own default.
        n_repetitions: Repetitions for both sampling-based metrics.

    Returns:
        dict with keys "neighbor_kept_ratio", "global_score", "rta" and
        "dist_corr", holding each metric's result.
    """
    # The dict literal evaluates its values in order. The triplet accuracy and
    # distance correlation both draw from numpy's global RNG, so this order
    # fixes which random samples each metric sees.
    return {
        "neighbor_kept_ratio": neighbor_kept_ratio_eval(X_high, X_low),
        "global_score": global_score(X_high, X_low),
        "rta": random_triplet_accuracy(
            X_high,
            X_low,
            neighbors_high,
            n_triplets=n_sample,
            n_repetitions=n_repetitions,
            neighborhood_size=neighborhood_size,
        ),
        "dist_corr": spearman_correlation_eval(
            X_high,
            X_low,
            n_points=n_sample,
            n_repetitions=n_repetitions,
        ),
    }